import h5py
import numpy as np
import os
from pymongo import MongoClient

# MongoDB database holding the per-video documents with ground-truth labels.
client = MongoClient(port=27017)
db = client['tiktok']

# Copy the parameters of the trained model's final dense layer ("fnd_output")
# into memory. The HDF5 file is opened in a context manager so the handle is
# released as soon as the values have been read.
with h5py.File('E:\\models\\vae_fnd_0.05_0.3\\weights\\18.hdf5', 'r') as f:
    layer = f['model_weights']['fnd_output']['fnd_output_1']
    # Single output unit: take the first (presumably only) bias entry.
    bias = layer['bias:0'][0]
    # Kernel is presumably shaped (n_features, 1) -- TODO confirm. Flattening
    # in C (row-major) order yields the same sequence as extending a list
    # row by row.
    weights = list(layer['kernel:0'][()].ravel())

# Precomputed feature vectors, one row per test sample; rows are assumed to be
# in the same order as the lines of the id file -- TODO confirm.
k = np.load('E:\\models\\vae_fnd_0.05_0.3\\features\\vae_fnd_d25.npy')

# Each non-blank line of the id file is split on tabs; the scoring step uses
# field 0 as a MongoDB collection name and field 1 as the document _id.
data = []
with open('E:\\data_pi\\test_ids_d25.txt', 'r', encoding='utf-8', newline='\n') as fin:
    for line in fin:
        line = line.strip()
        if not line:
            # A blank line would otherwise become a [''] row with no
            # collection/_id, breaking the positional pairing with `k`.
            continue
        data.append(line.split('\t'))
# Score every test sample with the dense output layer and attach its label.
# Three values are appended to each row of `data`: the sigmoid probability,
# the thresholded 0/1 prediction, and the integer `labelA` from MongoDB.
if len(k) != len(data):
    # Rows are paired positionally, so a count mismatch means the feature
    # file and the id file are out of sync.
    raise ValueError(
        f'feature rows ({len(k)}) and test ids ({len(data)}) differ in count'
    )
for row, obj in zip(data, k):
    # Dense layer output w.x + b, then a logistic sigmoid.
    lb = sum(ai * bi for ai, bi in zip(obj, weights)) + bias
    lb = 1 / (1 + np.exp(-lb))
    row.append(lb)
    row.append(1 if lb > 0.5 else 0)
    # Field 0 of the id line names the collection, field 1 is the _id.
    doc = db[row[0]].find_one({'_id': row[1]})
    if doc is None:
        raise LookupError(
            f'no document with _id {row[1]!r} in collection {row[0]!r}'
        )
    row.append(int(doc['video_feature']['label']['labelA']))
def _ratio(num, den):
    """Return ``num / den``, or NaN when ``den`` is 0 (metric undefined)."""
    return num / den if den else float('nan')


def _f1(precision, recall):
    """Return the harmonic mean of *precision* and *recall*.

    Returns 0.0 when both are 0 (the usual convention) and NaN when either
    input is NaN.
    """
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


# Confusion-matrix counts with label 1 as the positive class. Every row ends
# with [..., probability, predicted label, true label]; indexing from the end
# stays correct even if the id file carries more than two columns.
crt = fp = fn = tp = tn = 0
for res in data:
    print(res)
    pred, truth = res[-2], res[-1]
    if pred == truth:
        crt += 1
        if pred == 0:
            tn += 1
        else:
            tp += 1
    elif pred == 0:
        fn += 1  # predicted 0, actually 1
    else:
        fp += 1  # predicted 1, actually 0

# Overall accuracy (NaN if there were no samples).
print(_ratio(crt, len(data)))
# Class-1 metrics, disabled:
# print('precision_1', _ratio(tp, tp + fp))
# print('recall_1', _ratio(tp, tp + fn))
# print('f1_1', _f1(_ratio(tp, tp + fp), _ratio(tp, tp + fn)))

# Metrics treating label 0 as the class of interest.
precision_0 = _ratio(tn, tn + fn)
recall_0 = _ratio(tn, tn + fp)
print('precision_0', precision_0)
print('recall_0', recall_0)
print('f1_0', _f1(precision_0, recall_0))
